{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "5325ac27-38ea-47aa-afef-be4ec4f8f4b9",
   "metadata": {},
   "source": [
    "# Auto Merging Retriever\n",
    "\n",
    "In this notebook, we showcase our `AutoMergingRetriever`, which looks at a set of leaf nodes and recursively \"merges\" subsets of leaf nodes that reference a parent node beyond a given threshold. This allows us to consolidate potentially disparate, smaller contexts into a larger context that might help synthesis.\n",
    "\n",
    "You can define this hierarchy yourself over a set of documents, or you can make use of our brand-new text parser: a HierarchicalNodeParser that takes in a candidate set of documents and outputs an entire hierarchy of nodes, from \"coarse-to-fine\"."
   ]
  },
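  {
   "cell_type": "markdown",
   "id": "automerge-toy-sketch-md",
   "metadata": {},
   "source": [
    "The merging rule can be pictured with a small, self-contained sketch. This is not the library's implementation: `toy_auto_merge`, the toy IDs, and the 0.5 threshold are assumptions made for illustration, and the sketch performs a single merge pass, whereas the retriever applies the idea recursively up the hierarchy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "automerge-toy-sketch-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from collections import defaultdict\n",
    "\n",
    "\n",
    "def toy_auto_merge(retrieved_leaf_ids, parent_of, children_of, ratio_thresh=0.5):\n",
    "    \"\"\"One merge pass: swap retrieved siblings for their parent when enough of them were hit.\"\"\"\n",
    "    # group the retrieved leaves by their parent\n",
    "    hits_per_parent = defaultdict(list)\n",
    "    for leaf_id in retrieved_leaf_ids:\n",
    "        hits_per_parent[parent_of[leaf_id]].append(leaf_id)\n",
    "\n",
    "    merged = []\n",
    "    for parent_id, hits in hits_per_parent.items():\n",
    "        # fraction of this parent's children that were retrieved\n",
    "        ratio = len(hits) / len(children_of[parent_id])\n",
    "        if ratio > ratio_thresh:\n",
    "            merged.append(parent_id)  # enough coverage: return the parent instead\n",
    "        else:\n",
    "            merged.extend(hits)  # too sparse: keep the individual leaves\n",
    "    return merged\n",
    "\n",
    "\n",
    "# toy hierarchy (hypothetical IDs): two parents with four leaf children each\n",
    "toy_children_of = {\"P1\": [\"a\", \"b\", \"c\", \"d\"], \"P2\": [\"e\", \"f\", \"g\", \"h\"]}\n",
    "toy_parent_of = {c: p for p, cs in toy_children_of.items() for c in cs}\n",
    "\n",
    "# 3 of P1's 4 children were retrieved (0.75 > 0.5), but only 1 of P2's 4 (0.25)\n",
    "toy_auto_merge([\"a\", \"b\", \"c\", \"e\"], toy_parent_of, toy_children_of)  # -> ['P1', 'e']"
   ]
  },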
  {
   "cell_type": "markdown",
   "id": "7e1316ac-84ca-41d0-80f9-d4ef758e653c",
   "metadata": {},
   "source": [
    "## Load Data\n",
    "\n",
    "Let's first load the Llama 2 paper: https://arxiv.org/pdf/2307.09288.pdf. This will be our test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80372299-ab32-4ddd-9b88-05c877120c17",
   "metadata": {},
   "outputs": [],
   "source": [
    "!wget --user-agent \"Mozilla\" \"https://arxiv.org/pdf/2307.09288.pdf\" -O \"data/llama2.pdf\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "5f9c5d99-bd0e-4b26-b816-9f5ad29df3c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "from llama_hub.file.pdf.base import PDFReader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "723f1f02-2157-4166-b013-90e627c76530",
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = PDFReader()\n",
    "docs0 = loader.load_data(file=Path(\"./data/llama2.pdf\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ff7a8552-f347-45b0-b4a0-4f9b32be57ac",
   "metadata": {},
   "source": [
    "By default, the PDF reader creates a separate doc for each page.\n",
    "For the sake of this notebook, we stitch docs together into one doc. \n",
    "This will help us better highlight auto-merging capabilities that \"stitch\" chunks together later on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a75c4217-ab50-417f-a8ed-3b746a9956c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index import Document\n",
    "\n",
    "doc_text = \"\\n\\n\".join([d.get_content() for d in docs0])\n",
    "docs = [Document(text=doc_text)]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "724fe6f1-80e1-4ac5-bd99-8b9b8d15bddd",
   "metadata": {},
   "source": [
    "## Parse Chunk Hierarchy from Text, Load into Storage\n",
    "\n",
    "In this section we make use of the `HierarchicalNodeParser`. This will output a hierarchy of nodes, from top-level nodes with bigger chunk sizes to child nodes with smaller chunk sizes, where each child node has a parent node with a bigger chunk size.\n",
    "\n",
    "By default, the hierarchy is:\n",
    "- 1st level: chunk size 2048\n",
    "- 2nd level: chunk size 512\n",
    "- 3rd level: chunk size 128\n",
    "\n",
    "\n",
    "We then load these nodes into storage. The leaf nodes are indexed and retrieved via a vector store - these are the nodes that will first be directly retrieved via similarity search. The other nodes will be retrieved from a docstore."
   ]
  },
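  {
   "cell_type": "markdown",
   "id": "hier-parser-chunk-sizes-md",
   "metadata": {},
   "source": [
    "The defaults can be overridden. As a minimal sketch, assuming `HierarchicalNodeParser.from_defaults` accepts a `chunk_sizes` list ordered from largest to smallest, the cell below spells out the default levels explicitly. `explicit_parser` is a name used only for this illustration; the rest of the notebook keeps using the default parser."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hier-parser-chunk-sizes-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.node_parser import HierarchicalNodeParser\n",
    "\n",
    "# same levels as the defaults listed above, written out explicitly;\n",
    "# edit the list to experiment with a different hierarchy\n",
    "explicit_parser = HierarchicalNodeParser.from_defaults(chunk_sizes=[2048, 512, 128])"
   ]
  },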
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "45e783f5-a323-4f51-ae9a-4b71b00e5e11",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.node_parser import HierarchicalNodeParser, SimpleNodeParser"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "2c3947df-25c2-4254-a3d4-381d136f3f77",
   "metadata": {},
   "outputs": [],
   "source": [
    "node_parser = HierarchicalNodeParser.from_defaults()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2162b309-dfc5-484b-a31c-24f705316f10",
   "metadata": {},
   "outputs": [],
   "source": [
    "nodes = node_parser.get_nodes_from_documents(docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a9b5bc9b-389d-47db-a41c-3eb5b3d38ac5",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "999"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(nodes)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7456b70-1803-4786-86d5-26e202e0f318",
   "metadata": {},
   "source": [
    "Here we import a simple helper function for fetching \"leaf\" nodes within a node list. \n",
    "These are nodes that don't have children of their own."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7299ca7e-09b6-432f-a277-aae9eca0522a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.node_parser import get_leaf_nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "faeb37a8-aea9-4ee8-b6c0-3b2f188d244e",
   "metadata": {},
   "outputs": [],
   "source": [
    "leaf_nodes = get_leaf_nodes(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "7c33b5a8-4d9f-481e-8616-fc8717900159",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "783"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(leaf_nodes)"
   ]
  },
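  {
   "cell_type": "markdown",
   "id": "hier-root-count-md",
   "metadata": {},
   "source": [
    "To get a feel for the shape of the hierarchy, we can also count the top-level nodes. This is a small sketch that assumes top-level nodes are exactly those without a parent relationship (their `parent_node` is `None`); `root_nodes` is a name introduced here for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "hier-root-count-code",
   "metadata": {},
   "outputs": [],
   "source": [
    "# top-level nodes have no parent; everything else is an intermediate or leaf node\n",
    "root_nodes = [n for n in nodes if n.parent_node is None]\n",
    "print(f\"root: {len(root_nodes)}, leaf: {len(leaf_nodes)}, total: {len(nodes)}\")"
   ]
  },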
  {
   "cell_type": "markdown",
   "id": "c36ec940-8af7-45f5-9994-919d57583c24",
   "metadata": {},
   "source": [
    "### Load into Storage\n",
    "\n",
    "We define a docstore, which we load all nodes into. \n",
    "\n",
    "We then define a `VectorStoreIndex` containing just the leaf-level nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "27c8f2cd-3e04-4feb-937b-b9ee33e1c2fd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# define storage context\n",
    "from llama_index.storage.docstore import SimpleDocumentStore\n",
    "from llama_index.storage import StorageContext\n",
    "\n",
    "docstore = SimpleDocumentStore()\n",
    "\n",
    "# insert nodes into docstore\n",
    "docstore.add_documents(nodes)\n",
    "\n",
    "# define storage context (will include vector store by default too)\n",
    "storage_context = StorageContext.from_defaults(docstore=docstore)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "827ece8e-7a4b-4ee1-8ee2-3433d7f2072a",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Load index into vector index\n",
    "from llama_index import VectorStoreIndex\n",
    "\n",
    "base_index = VectorStoreIndex(leaf_nodes, storage_context=storage_context)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "05d84c19-c9ac-4294-a000-264c3c02427b",
   "metadata": {},
   "source": [
    "## Define Retriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "e61682a0-dd3c-400b-8734-35d5d0a98252",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.retrievers.auto_merging_retriever import AutoMergingRetriever"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "f96fd0bc-c6c0-4073-a692-d1803cf4289f",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_retriever = base_index.as_retriever(similarity_top_k=6)\n",
    "retriever = AutoMergingRetriever(base_retriever, storage_context, verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "62f655cd-4195-4398-80e5-5aa561982d25",
   "metadata": {},
   "outputs": [],
   "source": [
    "# query_str = \"What were some lessons learned from red-teaming?\"\n",
    "query_str = \"Can you tell me about the key concepts for safety finetuning\"\n",
    "\n",
    "nodes = retriever.retrieve(query_str)\n",
    "base_nodes = base_retriever.retrieve(query_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "77eabc56-2009-4504-8832-b6d857bd43a4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "5"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d482b22-fd38-476b-821f-0c77564815c3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.response.notebook_utils import display_source_node\n",
    "\n",
    "for node in nodes:\n",
    "    display_source_node(node, source_length=10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e4dd58db-4b12-49dc-b42f-8a0ee746f5c9",
   "metadata": {},
   "outputs": [],
   "source": [
    "for node in base_nodes:\n",
    "    display_source_node(node, source_length=10000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08f62e2c-4def-402e-8904-47f34d12c2fb",
   "metadata": {},
   "source": [
    "## Plug it into Query Engine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5d3ce9ec-f6cd-475b-94fa-3e8df81ab824",
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.query_engine import RetrieverQueryEngine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f106e1bb-58bc-48bf-a46b-e527339f83c5",
   "metadata": {},
   "outputs": [],
   "source": [
    "query_engine = RetrieverQueryEngine.from_args(retriever)\n",
    "base_query_engine = RetrieverQueryEngine.from_args(base_retriever)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94a85854-ca04-41ed-9f44-b6dce1e513e1",
   "metadata": {},
   "outputs": [],
   "source": [
    "response = query_engine.query(query_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "8b334b7b-fcb8-4057-a418-b8d8c425ad14",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The key concepts for safety fine-tuning include supervised safety fine-tuning and safety RLHF (Reinforcement Learning with Human Feedback). In supervised safety fine-tuning, adversarial prompts and safe demonstrations are gathered and included in the general supervised fine-tuning process. This helps the model align with safety guidelines even before RLHF. Safety RLHF involves collecting human preference data by having annotators write prompts that they believe can elicit unsafe behavior. Multiple model responses are compared, and the safest response is selected according to a set of guidelines. This data is then used to train a safety reward model. These concepts are aimed at mitigating safety risks and improving the safety of language models.\n"
     ]
    }
   ],
   "source": [
    "print(str(response))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "1c38a124-5279-4a43-a2fe-ed2cbce9bd66",
   "metadata": {},
   "outputs": [],
   "source": [
    "base_response = base_query_engine.query(query_str)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5c2910e5-1a45-4de5-8035-5b5a47125d81",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The key concepts for safety fine-tuning include supervised safety fine-tuning and safety RLHF (Reinforcement Learning with Human Feedback). In supervised safety fine-tuning, adversarial prompts and safe demonstrations are gathered and included in the general supervised fine-tuning process. This helps the model align with safety guidelines even before RLHF. Safety RLHF involves observing and generalizing from safe demonstrations in supervised fine-tuning. The model learns to write detailed safe responses, address safety concerns, explain sensitive topics, and provide additional helpful information. These concepts aim to mitigate safety risks and improve the safety of language model models.\n"
     ]
    }
   ],
   "source": [
    "print(str(base_response))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama_index_v2",
   "language": "python",
   "name": "llama_index_v2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
